import asyncio
import datetime
import inspect
from typing import Optional

from google import genai
from google.genai import types
from loguru import logger

from app.config import GOOGLE_GENAI_API_KEY
from llm_generation.models.base import BaseModel
from llm_generation.models.data_structure import StreamingDelta


class Gemini(BaseModel):
    """Chat model backed by Google Gemini through the ``google-genai`` SDK.

    Every request enables the Google Search tool, so answers may be grounded
    in web search results.
    """

    # Map chat-message roles to Gemini ``Content.role`` values. "system" is
    # handled separately because Gemini takes it as a config field, not a turn.
    _ROLE_MAP = {"user": "user", "assistant": "model"}

    def __init__(self, model_name: str = "gemini-2.5-pro-preview-03-25"):
        super().__init__(model_name)

    async def generate_response(
        self, user_prompt: str, conversation: Optional[list] = None, **kwargs
    ) -> str:
        """Send ``user_prompt`` (after ``conversation``) to Gemini and return the reply.

        Args:
            user_prompt: The new user message, appended after the history.
            conversation: Optional prior messages, each a dict with ``"role"``
                (``"system"``, ``"user"`` or ``"assistant"``) and ``"content"``.
                Messages with any other role are logged and skipped.
            **kwargs: Only ``stream`` is read. When truthy, the response is
                streamed and ``self.streaming_callback`` (if set) is called
                once per chunk that carries candidate content, as
                ``callback(text_so_far, StreamingDelta)``. ``text_so_far``
                does NOT yet include that chunk's delta. The callback may be
                a plain function or return an awaitable.

        Returns:
            The concatenated text of the first candidate, or ``""`` when the
            model returned no usable content.
        """
        system_instruction, contents = self._build_contents(user_prompt, conversation)
        config = types.GenerateContentConfig(
            tools=[types.Tool(google_search=types.GoogleSearch())],
            response_mime_type="text/plain",
            system_instruction=system_instruction,
        )
        # NOTE(review): a fresh client is built on every call; consider reusing
        # one if this method is called frequently.
        gemini_client = genai.Client(api_key=GOOGLE_GENAI_API_KEY)

        if kwargs.get("stream"):
            return await self._generate_streaming(gemini_client, contents, config)

        raw_response = await gemini_client.aio.models.generate_content(
            model=self.model_name,
            contents=contents,
            config=config,
        )
        if not raw_response.candidates:
            # Nothing to extract (presumably the prompt was blocked);
            # prompt_feedback carries the reason, if any.
            logger.warning(
                "Gemini returned no candidates "
                f"(prompt_feedback={raw_response.prompt_feedback})"
            )
            return ""
        return self._extract_text(raw_response.candidates[0].content)

    @classmethod
    def _build_contents(
        cls, user_prompt: str, conversation: Optional[list]
    ) -> tuple:
        """Convert chat messages plus the new prompt into Gemini request inputs.

        Returns:
            ``(system_instruction, contents)``. ``system_instruction`` is a
            one-element list of ``Part`` built from the last ``"system"``
            message, or ``None``. ``contents`` is a list of ``Content`` ending
            with the user prompt.
        """
        system_instruction = None
        contents = []
        for message in conversation or []:
            role = message["role"]
            if role == "system":
                # A later system message replaces an earlier one.
                system_instruction = [types.Part.from_text(text=message["content"])]
            elif role in cls._ROLE_MAP:
                contents.append(
                    types.Content(
                        role=cls._ROLE_MAP[role],
                        parts=[types.Part.from_text(text=message["content"])],
                    )
                )
            else:
                logger.error(f"Unknown role: {message['role']}")

        contents.append(
            types.Content(role="user", parts=[types.Part.from_text(text=user_prompt)])
        )
        return system_instruction, contents

    async def _generate_streaming(self, gemini_client, contents: list, config) -> str:
        """Stream a response, forwarding each chunk to the callback; return the full text."""
        if self.streaming_callback is None:
            logger.warning(
                "No streaming callback is set, skipping callback function"
            )

        content = ""
        async for chunk in await gemini_client.aio.models.generate_content_stream(
            model=self.model_name,
            contents=contents,
            config=config,
        ):
            # Chunks without candidate content carry no text and trigger no callback.
            if not chunk.candidates or chunk.candidates[0].content is None:
                continue
            delta = self._extract_text(chunk.candidates[0].content)
            # The callback sees the text accumulated *before* this delta.
            await self._emit_delta(content, delta)
            content += delta
        return content

    async def _emit_delta(self, content: str, delta: str) -> None:
        """Invoke the streaming callback with ``delta`` if one is set (sync or async)."""
        if not self.streaming_callback:
            return
        result = self.streaming_callback(
            content, StreamingDelta(content=delta, is_thinking=False)
        )
        # Await anything awaitable so both plain and async callbacks work.
        if inspect.isawaitable(result):
            await result

    @staticmethod
    def _extract_text(content) -> str:
        """Concatenate the text of every part of a Gemini ``Content``; ``""`` if none."""
        if content is None or not content.parts:
            return ""
        return "".join(part.text for part in content.parts if part.text)


async def main():
    """Manual smoke test: stream a search-grounded answer and print it."""

    async def on_chunk(_content, delta):
        # Timestamp every chunk so the streaming cadence is visible.
        print(datetime.datetime.now())
        print(delta)

    model = Gemini()
    model.set_streaming_callback(on_chunk)
    answer = await model.generate_response(
        "what's going on the crypto market?", stream=True
    )
    separator = "=" * 20
    print(separator)
    print(answer)


if __name__ == "__main__":
    # Script entry point: run the async demo on a fresh event loop.
    asyncio.run(main())
